#!/usr/bin/python
# Full assembly of the sub-networks (k-space blocks + complex image-domain U-Net) into the complete net.

from utils.saveNet import *
import torch
import torch.nn as nn

# Global hyper-parameter object. ``Parameters`` is not imported by name here;
# it is expected to come from the ``utils.saveNet`` star import above.
# CUNet reads ``params.moving_window_size`` when building its input layer.
params = Parameters()
from utils.cmplxBatchNorm import normalizeComplexBatch_byMagnitudeOnly
import unet.unet_complex_parts as unet_cmplx

class CUNet(nn.Module):
    """Complex-valued U-Net used as the image-domain subnetwork.

    Two 3D convolutions (``in_ch -> 32 -> in_ch``) are applied first. Their
    output then has dims 1 and 2 merged into one channel axis, and the result
    goes through a two-level encoder/decoder and a final ``outconv`` producing
    ``out_ch`` channels.

    Args:
        in_ch: channel count of the input seen by the 3D front end.
        out_ch: channel count produced by the final ``outconv``.
    """

    def __init__(self, in_ch, out_ch):
        super(CUNet, self).__init__()
        # 3D front end; the second conv restores ``in_ch`` channels.
        self.conv_3D_1 = unet_cmplx.conv_3D(in_ch, 32, kernel_size=(3, 3, 3), dilation=1,
                                            apply_BN=False, apply_activation=True)
        self.conv_3D_2 = unet_cmplx.conv_3D(32, in_ch, kernel_size=(3, 3, 3), dilation=1,
                                            apply_BN=False, apply_activation=True)

        # Encoder. ``inc`` expects in_ch * moving_window_size channels, i.e. the
        # result of folding dims 1 and 2 together in forward().
        self.inc = unet_cmplx.inconv(in_ch * params.moving_window_size, 32)
        self.down1 = unet_cmplx.down(32, 64)
        self.down2 = unet_cmplx.down(64, 128)
        self.bottleneck = unet_cmplx.bottleneck(128, 128, False)

        # Decoder and output projection.
        self.up2 = unet_cmplx.up(128, 64)
        self.up3 = unet_cmplx.up(64, 32)
        self.up4 = unet_cmplx.up(32, 32)
        self.ouc = unet_cmplx.outconv(32, out_ch)

    def forward(self, x):
        """Run the U-Net on ``x`` and return the ``outconv`` output.

        ``x`` is assumed to be laid out as (batch, channel, window, H, W, re/im).
        TODO confirm against the data loader. After the 3D front end,
        dims 1 and 2 are merged, so their product must equal
        ``in_ch * params.moving_window_size``.
        """
        x = self.conv_3D_1(x)
        x = self.conv_3D_2(x)

        # Fold channel and window axes into a single channel axis for the 2D U-Net.
        x = x.flatten(1, 2)

        x1 = self.inc(x)
        # down() returns a pair; only the first element is used downstream.
        x2, _ = self.down1(x1)
        x3, _ = self.down2(x2)
        x4 = self.bottleneck(x3)

        # Each up() receives the previous output plus the matching encoder
        # tensor (presumably a skip connection; see unet_complex_parts).
        x = self.up2(x4, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.ouc(x)

class Net(nn.Module):
    """K-space refinement followed by an image-domain complex U-Net.

    ``forward`` works in three stages:

    1. Apply two residual stacked-3D-conv blocks to the k-space input.
    2. Take an orthonormal 2-D inverse FFT to image space and apply
       magnitude normalisation.
    3. Run ``CUNet``.

    Args:
        in_channels: channel count of the k-space input. The k-space blocks
            preserve it, and it is passed to ``CUNet`` as ``in_ch``.
        out_channels: channel count of the network output.
    """

    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()

        # K-space subnetwork: in_channels -> in_channels, because forward() adds
        # each block's output back onto its input (residual connection).
        self.stc3D_k1 = unet_cmplx.stacked_3Dconvs_block(in_channels, in_channels, kernel_size=(3, 5, 5), dilations=1,
                                                          apply_BN=False, apply_activation=False)
        self.stc3D_k2 = unet_cmplx.stacked_3Dconvs_block(in_channels, in_channels, kernel_size=(3, 5, 5), dilations=1,
                                                          apply_BN=False, apply_activation=False)

        # Image-domain U-Net
        self.unet_img = CUNet(in_channels, out_channels)

    @staticmethod
    def _ifft2_normalized(x):
        """Orthonormal 2-D inverse FFT of a real tensor with a trailing re/im axis.

        Equivalent to the legacy ``torch.ifft(x, 2, normalized=True)``. The last
        axis (size 2) holds real/imaginary parts, and the transform runs over
        the two axes immediately before it with 1/sqrt(N) scaling.
        ``torch.ifft`` was removed in PyTorch 1.8, so this helper falls back to
        ``torch.fft.ifftn`` with ``norm="ortho"`` when the legacy function is
        absent.
        """
        legacy_ifft = getattr(torch, "ifft", None)
        if legacy_ifft is not None:
            return legacy_ifft(x, 2, normalized=True)
        # view_as_complex needs the re/im axis to be innermost with stride 1.
        x_complex = torch.view_as_complex(x.contiguous())
        x_image = torch.fft.ifftn(x_complex, dim=(-2, -1), norm="ortho")
        return torch.view_as_real(x_image)

    def forward(self, x, Loc_xy=None):
        """Reconstruct from k-space input ``x``.

        Args:
            x: k-space tensor whose last axis (size 2) holds real/imaginary
                parts, as required by the inverse FFT.
            Loc_xy: unused; kept so existing call sites keep working.

        Returns:
            The output of the image-domain ``CUNet``.
        """
        # K-space subnetwork (residual blocks)
        x = self.stc3D_k1(x) + x
        x = self.stc3D_k2(x) + x

        # Inverse Fourier transform to image space, then magnitude normalisation
        x = self._ifft2_normalized(x)
        x = normalizeComplexBatch_byMagnitudeOnly(x, normalize_over_channel=True)

        # Image-domain subnetwork
        return self.unet_img(x)


